{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "SAC的简化版实现,alpha使用常量代替.只使用一个value模型,而不是两个."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAR8AAAEXCAYAAACUBEAgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAch0lEQVR4nO3df3BTdb438PdJ06Q/k9JCE3pppXOLYgfKSoES9Y7zLFmqdl2UurPrZbDDsjpiYEC8zNpdwVlnZ8oDM+vqrqIzOyvcP7Q77N3qyoJun4JFJBQoVMuvqivaLpBUwCaltOmPfJ4/pOcarZi0Tb5Jfb9mzow530+SzwHyNudHvkcTEQERUYwZVDdARN9NDB8iUoLhQ0RKMHyISAmGDxEpwfAhIiUYPkSkBMOHiJRg+BCREgwfIlJCWfg8//zzmD59OlJSUlBWVobDhw+raoWIFFASPn/+85+xfv16PPXUUzh27BjmzJmD8vJydHZ2qmiHiBTQVPywtKysDPPnz8cf/vAHAEAwGER+fj7WrFmDJ5544lufHwwGcf78eWRmZkLTtGi3S0RhEhF0d3cjLy8PBsP1v9sYY9STrr+/H83NzaiurtbXGQwGOJ1OuN3uEZ8TCAQQCAT0x+fOnUNxcXHUeyWi0eno6MC0adOuWxPz8Ll48SKGhoZgs9lC1ttsNpw5c2bE59TU1ODXv/7119Z3dHTAYrFEpU8iipzf70d+fj4yMzO/tTbm4TMa1dXVWL9+vf54eAMtFgvDhygOhXM4JObhM3nyZCQlJcHr9Yas93q9sNvtIz7HbDbDbDbHoj0iipGYn+0ymUwoLS1FQ0ODvi4YDKKhoQEOhyPW7RCRIkp2u9avX4+qqirMmzcPCxYswO9+9zv09PRgxYoVKtohIgWUhM9PfvITfPbZZ9i0aRM8Hg++973v4c033/zaQWgimriUXOczVn6/H1arFT6fjwecieJIJJ9N/raLiJRg+BCREgwfIlKC4UNESjB8iEgJhg8RKcHwISIlGD5EpATDh4iUYPgQkRIMHyJSguFDREowfIhICYYPESnB8CEiJRg+RKQEw4eIlGD4EJESDB8iUoLhQ0RKMHyISAmGDxEpwfAhIiUYPkSkBMOHiJRg+BCREgwfIlKC4UNESjB8iEgJhg8RKcHwISIlGD5EpATDh4iUYPgQkRIMHyJSIuLw2b9/P+655x7k5eVB0zS89tprIeMigk2bNmHq1KlITU2F0+nEhx9+GFJz+fJlLFu2DBaLBVlZWVi5ciWuXLkypg0hosQScfj09PRgzpw5eP7550cc37JlC5577jm8+OKLaGpqQnp6OsrLy9HX16fXLFu2DCdPnkR9fT127dqF/fv34+GHHx79VhBR4pExACB1dXX642AwKHa7XbZu3aqv6+rqErPZLK+++qqIiJw6dUoAyJEjR/SaPXv2iKZpcu7cubDe1+fzCQDx+XxjaZ+Ixlkkn81xPeZz9uxZeDweOJ1OfZ3VakVZWRncbjcAwO12IysrC/PmzdNrnE4nDAYDmpqaRnzdQCAAv98fshBRYhvX8PF4PAAAm80Wst5ms+ljHo8Hubm5IeNGoxHZ2dl6zVfV1NTAarXqS35+/ni2TUQKJMTZrurqavh8Pn3p6OhQ3RIRjdG4ho/dbgcAeL3ekPVer1cfs9vt6OzsDBkfHBzE5cuX9ZqvMpvNsFgsIQsRJbZxDZ/CwkLY7XY0NDTo6/x+P5qamuBwOAAADocDXV1daG5u1mv27t2LYDCIsrKy8WyHiOKYMdInXLlyBR999JH++OzZs2hpaUF2djYKCgqwbt06/OY3v8GMGTNQWFiIjRs3Ii8vD/feey8A4Oabb8add96Jhx56CC+++CIGBgawevVq/PSnP0VeXt64bRgRxblIT6Xt27dPAHxtqaqqEpEvTrdv3LhRbDabmM1mWbRokbS1tYW8xqVLl+SBBx6QjIwMsVgssmLFCunu7g67B55qJ4pPkXw2NRERhdk3Kn6/H1arFT6fj8d/iOJIJJ/NhDjbRUQTD8OHiJRg+BCREgwfIlKC4UNESjB8iEgJhg8RKcHwISIlGD5EpATDh4iUYPgQkRIMHyJSguFDREowfIhICYYPESnB8CEiJRg+RKQEw4eIlGD4EJESDB8iUoLhQ0RKMHyISAmGDxEpwfAhIiUYPkSkBMOHiJRg+BCREgwfIlKC4UNESjB8iEgJhg8RKcHwISIlGD5EpATDh4iUYPgQkRIRhU9NTQ3mz5+PzMxM5Obm4t5770VbW1tITV9fH1wuF3JycpCRkYHKykp4vd6Qmvb2dlRUVCAtLQ25ubnYsGEDBgcHx741RJQwIgqfxsZGuFwuHDp0CPX19RgYGMDixYvR09Oj1zz22GN44403sHPnTjQ2NuL8+fNYunSpPj40NISKigr09/fj4MGD2LFjB7Zv345NmzaN31YRUfyTMejs7BQA0tjYKCIiXV1dkpycLDt37tRrTp8+LQDE7XaLiMju3bvFYDCIx+PRa7Zt2yYWi0UCgUBY7+vz+QSA+Hy+sbRPROMsks/mmI75+Hw+AEB2djYAoLm5GQMDA3A6nXrNzJkzUVBQALfbDQBwu92YPXs2bDabXlNeXg6/34+TJ0+O+D6BQAB+vz9kIaLENurwCQaDWLduHW677TbMmjULAODxeGAymZCVlRVSa7PZ4PF49JovB8/w+PDYSGpqamC1WvUlPz9/tG0TUZwYdfi4XC6cOHECtbW149nPiKqrq+Hz+fSlo6Mj6u9JRNFlHM2TVq9ejV27dmH//v2YNm2avt5ut6O/vx9dXV0h3368Xi/sdrtec/jw4ZDXGz4bNlzzVWazGWazeTStElGciuibj4hg9erVqKurw969e1FYWBgyXlpaiuTkZDQ0NOjr2tra0N7eDofDAQBwOBxobW1FZ2enXlNfXw+LxYLi4uKxbAsRJZCIvvm4XC688soreP3115GZmakfo7FarUhNTYXVasXKlSuxfv16ZGdnw2KxYM2aNXA4HFi4cCEAYPHixSguLsby5cuxZcsWeDwePPnkk3C5XPx2Q/RdEslpNAAjLi+//LJe09vbK48++qhMmjRJ0tLS5L777pMLFy6EvM4nn3wid911l6SmpsrkyZPl8ccfl4GBgbD74Kl2ovgUyWdTExFRF32j4/f7YbVa4fP5YLFYVLdDRNdE8tnkb7uISAmGDxEpwfAhIiUYPkSkBMOHiJRg+BCREgwfIlKC4UNESjB8iEiJUf2qnWg8SDCIoZ4eBPv7oSUlISk1FZrJBE3TVLdGMcDwoZgTEQxcvozP9uyB7/Bh9F+8CIPZjLSiIuTefTcyS0qgJSWpbpOijOFDMSUiCJw/j0+efRY9bW3AtZ8WDl25At+lS7jS2oq8Bx/ElPJyBtAEx2M+FFNDV6+i/aWX0HPmDCQYxOeBAI5evIgP/X4ERTB09SrO/fd/w3fsGBLwN88UAX7zoZjyHTmC7vffh4igvacHG48fR5vPh3SjET+/8Ub8pLAQuHoV3ro6ZM6ahaTUVNUtU5Twmw/F1JVTp4BgEALg/7a24lRXF4ZE4B8YwB9On8aJzz8HAFz9+GMM9faqbZaiiuFDyvgHBkIe9weDCAwNffFgaAiD14KIJiaGDymhAfg/djuMXzqtfqPFghsyMgAAwUAAXU1NPO4zgfGYD8VUUno6AEDTNFQVFSEzORn/78IFTE1NxUM33ojclBS9tuvwYdjuu4/HfSYohg/FVPbtt6Pz9dchg4MwGgz48fTpuH/6dAx///nyBYYBjwd9HR1Iv/FGNc1SVHG3i2LKPHUq0v793/XHmqbBoGnQri1fFrx6Ff6WFu56TVAMH4qppLQ0ZFy7vXY4Pn/3XchXDkzTxMDwoZib5HCEffVywOPB1X/+M8odkQoMH4q51BtuQOqXdr2uJ9jbi+7WVu56TUAMH4o5zWSCde7csOs/P3gQMjgYxY5IBYYPxZymabDccgs0kyms+r5//Yu7XhMQw4eUSCsqQsq0aWHVSn8//MePc9drgmH4kBKa0YhJt94adr3/2DFIf38UO6JYY/iQEpqmIXPWLBjCvHr56j//id729ih3RbHE8CFl0oqKYLbZwqqVwUF0ud1R7ohiieFDymjJyZh0++1h13efOIGhq1ej2BHFEsOHlNE0DZlz5oS/6/XRRwh4PFHuimKF4UNKpRYUwGy3h1Urg4P4/J13eNZrgmD4kFKGlBRY588Pu97f2oogd70mBIYPKaVpGrIWLoTBbA6rvu/TT7nrNUEwfEg589SpSMnPD6s2GAig69Ah7npNAAwfUi4pLQ2ZJSVh13cdOYJgX18UO6JYiCh8tm3bhpKSElgsFlgsFjgcDuzZs0cf7+vrg8vlQk5ODjIyMlBZWQmv1xvyGu3t7aioqEBaWhpyc3OxYcMGDPJHg99pmqYh+z/+A5oxvIk1h2c4pMQWUfhMmzYNmzdvRnNzM44ePYrvf//7WLJkCU6ePAkAeOyxx/DGG29g586daGxsxPnz57F06VL9+UNDQ6ioqEB/fz8OHjyIHTt2YPv27di0adP4bhUlHLPdHjLD4fUEr16F/733uOuV4DQZ499gdnY2tm7divvvvx9TpkzBK6+8gvvvvx8AcObMGdx8881wu91YuHAh9uzZgx/+8Ic4f/48bNeubH3xxRfxi1/8Ap999hlM3/Ar50AggEAgoD/2+/3Iz8+Hz+eDxWIZS/sUR/61Ywe8//M/YdWmFhZi5tatMIT5y3iKDb/fD6vVGtZnc9THfIaGhlBbW4uenh44HA40NzdjYGAATqdTr5k5cyYKCgrgvnZZvNvtxuzZs/XgAYDy8nL4/X7929NIampqYLVa9SU/zIOTlFg4w+F3S8Th09raioyMDJjNZjzyyCOoq6tDcXExPB4PTCYTsrKyQuptNhs8106NejyekOAZHh8e+ybV1dXw+Xz60sH9/QkpJdIZDk+c4K5XAos4fG666Sa0tLSgqakJq1atQlVVFU6dOhWN3nRms1k/yD280MRjMJlgveWWsOs/f/ddznCYwCIOH5PJhKKiIpSWlqKmpgZz5szBs88+C7vdjv7+fnR1dYXUe71e2K9dPm+327929mv4sT3MS+xp4tI0DZa5cyOb4fCjj6LcFUXLmK/zCQaDCAQCKC0tRXJyMhoaGvSxtrY2tLe3w+FwAAAcDgdaW1vR2dmp19TX18NisaC4uHisrdAEEPEMh7yvV8KK6I6l1dXVuOuuu1BQUIDu7m688sorePvtt/HWW2/BarVi5cqVWL9+PbKzs2GxWLBmzRo4HA4sXLgQALB48WIUFxdj+fLl2LJlCzweD5588km4XC6Yw7y8niY2zWjEJIcDvR9/HFa9v7kZ9qVLofHfT8KJKHw6Ozvx4IMP4sKFC7BarSgpKcFbb72FH/zgBwCAZ555BgaDAZWVlQgEAigvL8cLL7ygPz8pKQm7du3CqlWr4HA4kJ6ejqqqKjz99NPju1WUsDRNQ+bs2TCkpiLY2/ut9Vc//hi9n37KWyonoDFf56NCJNcSUOIJ9vfjzH/9F3o/+SSsevuPf4x/W748uk1RWGJynQ9RtEQ8w2FrK2c4TEAMH4o7o5rh8MKFKHdF443hQ3Ep0hkOLx84wLNeCYbhQ3HJkJIC64IFYddf4QyHCYfhQ3FJ0zRklZWFPcNhL2c4TDgMH4pbKVOnIqWgIKxaznCYeBg+FLcMnOFwQmP4UNzSNA3Zt9/OGQ4nKIYPxbWIZzjkb70SBsOH4lpSejoyZs0Ku/7zd9+FDAxEsSMaLwwfinuTbr01shkOw/xRKqnF8KG4l5KfH/6uF2c4TBgMH4p7SSkpsEQyw+GBA5zhMAEwfCghWObOhZacHFZt4Nw5znCYABg+lBDSiooiuqUy7+sV/xg+lBCGZzgM16W9ezFw6VIUO6KxYvhQQtA0DRmzZ8OQkhJWfb/Hg0v79kW5KxoLhg8ljPSiorCn2QDA4z5xjuFDCSPSGQ4pvjF8KGFomobMkhIY0tLCqjdyfu+4xvChhJI+YwZsS5YA33LFc1J6OqbceWeMuqLRYPhQQtGSkmBbsgTZt98OGEb+52swmzH1P/8TqYWFMe6OIhHRfbuI4kFSWhqm/fznMFosuPT22xjq6QGCQWhGI5InT4a9shI53/8+tG8IJ4oPDB9KSMlWK/5txQrkOJ3o+eADDPX0wJSTg/TiYphychg8CYDhQwnLYDQirbAQady9Skj83wMRKcHwISIlGD5EpATDh4iUYPgQkRIMHyJSguFDREowfIhICYYPESnB8CEiJcYUPps3b4amaVi3bp2+rq+vDy6XCzk5OcjIyEBlZSW8Xm/I89rb21FRUYG0tDTk5uZiw4YNGOStToi+U0YdPkeOHMFLL72EkpKSkPWPPfYY3njjDezcuRONjY04f/48li5dqo8PDQ2hoqIC/f39OHjwIHbs2IHt27dj06ZNo98KIko8Mgrd3d0yY8YMqa+vlzvuuEPWrl0rIiJdXV2SnJwsO3fu1GtPnz4tAMTtdouIyO7du8VgMIjH49Frtm3bJhaLRQKBwIjv19fXJz6fT186OjoEgPh8vtG0T0RR4vP5wv5sjuqbj8vlQkVFBZxOZ8j65uZmDAwMhKyfOXMmCgoK4Ha7AQButxuzZ8+GzWbTa8rLy+H3+3Hy5MkR36+mpgZWq1Vf8sO8fxMRxa+Iw6e2thbHjh1DTU3N18Y8Hg9MJhOysrJC1ttsNng8Hr3my8EzPD48NpLq6mr4fD596ejoiLRtIoozEc3n09HRgbVr16K+vh4pYd4/aTyYzWaYzeaYvR8RRV9E33yam5vR2dmJuXPnwmg0wmg0orGxEc899xyMRiNsNhv6+/vR1dUV8jyv1wv7tfst2e32r539Gn5sj+CeTESU2CIKn0WLFqG1tRUtLS36Mm/ePCxbtkz/7+TkZDQ0NOjPaWtrQ3t7OxzXbnXrcDjQ2tqKzs5Ovaa+vh4WiwXFxcXjtFlEFO8i2u3KzMzErFmzQtalp6cjJydHX79y5UqsX78e2dnZsFgsWLNmDRwOBxYuXAgAWLx4MYqLi7F8+XJs2bIFHo8HTz75JFwuF3etiL5Dxn0O52eeeQYGgwGVlZUIBAIoLy/HCy+8oI8nJSVh165dWLVqFRwOB9LT01FVVYWnn356vFshojimiYiobiJSfr8fVqsVPp8PFt6VkihuRPLZ5G+7iEgJhg8RKcHwISIlGD5EpATDh4iUYPgQkRIMHyJSguFDREowfIhICYYPESnB8CEiJRg+RKQEw4eIlGD4EJESDB8iUoLhQ0RKMHyISAmGDxEpwfAhIiUYPkSkBMOHiJRg+BCREgwfIlKC4UNESjB8iEgJhg8RKcHwISIlGD5EpATDh4iUYPgQkRIMHyJSguFDREowfIhICYYPESnB8CEiJRg+RKQEw4eIlDCqbmA0RAQA4Pf7FXdCRF82/Jkc/oxeT0KGz6VLlwAA+fn5ijshopF0d3fDarVetyYhwyc7OxsA0N7e/q0bGG/8fj/y8/PR0dEBi8Wiup2wse/YStS+RQTd3d3Iy8v71tqEDB+D4YtDVVarNaH+Yr7MYrEkZO/sO7YSse9wvxDwgDMRKcHwISIlEjJ8zGYznnrqKZjNZtWtRCxRe2ffsZWofUdCk3DOiRERjbOE/OZDRImP4UNESjB8iEgJhg8RKcHwISIlEjJ8nn/+eUyfPh0pKSkoKyvD4cOHlfazf/9+3HPPPcjLy4OmaXjttddCxkUEmzZtwtSpU5Gamgqn04kPP/wwpOby5ctYtmwZLBYLsrKysHLlSly5ciWqfdfU1GD+/PnIzMxEbm4u7r33XrS1tYXU9PX1weVyIScnBxkZGaisrITX6w2paW9vR0VFBdLS0pCbm4sNGzZgcHAwan1v27YNJSUl+tW/DocDe/bsieueR7J582ZomoZ169YlXO/jQhJMbW2tmEwm+dOf/iQnT56Uhx56SLKyssTr9Srraffu3fKrX/1K/vrXvwoAqaurCxnfvHmzWK1Wee211+S9996TH/3oR1JYWCi9vb16zZ133ilz5syRQ4cOyTvvvCNFRUXywAMPRLXv8vJyefnll+XEiRPS0tIid999txQUFMiVK1f0mkceeUTy8/OloaFBjh49KgsXLpRbb71VHx8cHJRZs2aJ0+mU48ePy+7du2Xy5MlSXV0dtb7/9re/yd///nf54IMPpK2tTX75y19KcnKynDhxIm57/qrDhw/L9OnTpaSkRNauXauvT4Tex0vChc+CBQvE5XLpj4eGhiQvL09qamoUdvW/vho+wWBQ7Ha7bN26VV/X1dUlZrNZXn31VREROXXqlACQI0eO6DV79uwRTdPk3LlzMeu9s7NTAEhjY6PeZ3JysuzcuVOvOX36tAAQt9stIl8Er8FgEI/Ho9ds27ZNLBaLBAKBmPU+adIk+eMf/5gQPXd3d8uMGTOkvr5e7rjjDj18EqH38ZRQu139/f1obm6G0+nU1xkMBjidTrjdboWdfbOzZ8/C4/GE9Gy1WlFWVqb37Ha7kZWVhXnz5uk1TqcTBoMBTU1NMevV5/MB+N9ZA5qbmzEwMBDS+8yZM1FQUBDS++zZs2Gz2fSa8vJy+P1+nDx5Muo9Dw0Noba2Fj09PXA4HAnRs8vlQkVFRUiPQGL8eY+nhPpV+8WLFzE0NBTyBw8ANpsNZ86cUdTV9Xk8HgAYsefhMY/Hg9zc3JBxo9GI7OxsvSbagsEg1q1bh9tuuw2zZs3S+zKZTMjKyrpu7yNt2/BYtLS2tsLhcKCvrw8ZGRmoq6tDcXExWlpa4rZnAKitrcWxY8dw5MiRr43F8593NCRU+FD0uFwunDhxAgcOHFDdSlhuuukmtLS0wOfz4S9/+QuqqqrQ2Niouq3r6ujowNq1a1FfX4+UlBTV7SiXULtdkydPRlJS0teO/nu9XtjtdkVdXd9wX9fr2W63o7OzM2R8cHAQly9fjsl2rV69Grt27cK+ffswbdo0fb3dbkd/fz+6urqu2/tI2zY8Fi0mkwlFRUUoLS1FTU0N5syZg2effTaue25ubkZnZyfmzp0Lo9EIo9GIxsZGPPfcczAajbDZbHHbezQkVPiYTCaUlpaioaFBXxcMBtHQ0ACHw6Gws29WWFgIu90e0rPf70dTU5Pes8PhQFdXF5qbm/WavXv3IhgMoqysLGq9iQhWr16Nuro67N27F4WFhSHjpaWlSE5ODum9ra0N7e3tIb23traGhGd9fT0sFguKi4uj1vtXBYNBBAKBuO550aJFaG1tRUtLi77MmzcPy5Yt0/87XnuPCtVHvCNVW1srZrNZtm/fLqdOnZKHH35YsrKyQo7+x1p3d7ccP35cjh8/LgDkt7/9rRw/flw+/fRTEfniVHtWVpa8/vrr8v7778uSJUtGPNV+yy23SFNTkxw4cEBmzJgR9VPtq1atEqvVKm+//bZcuHBBX65evarXPPLII1JQUCB79+6Vo0ePisPhEIfDoY8Pn/pdvHixtLS0yJtvvilTpkyJ6qnfJ554QhobG+Xs2bPy/vvvyxNPPCGapsk//vGPuO35m3z5bFei9T5WCRc+IiK///3vpaCgQEwmkyxYsEAOHTqktJ99+/YJgK8tVVVVIvLF6faNGzeKzWYTs9ksixYtkra2tpDXuHTpkjzwwAOSkZEhFotFVqxYId3d3VHte6SeAcjLL7+s1/T29sqjjz4qkyZNkrS0NLnvvvvkwoULIa/zySefyF133SWpqakyefJkefzxx2VgYCBqff/sZz+TG264QUwmk0yZMkUWLVqkB0+89vxNvho+idT7WHE+HyJSIqGO+RDRxMHwISIlGD5EpATDh4iUYPgQkRIMHyJSguFDREowfIhICYYPESnB8CEiJRg+RKTE/wdKN/X2Ny6vtwAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 300x300 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import gym\n",
    "\n",
    "\n",
    "#定义环境\n",
    "class MyWrapper(gym.Wrapper):\n",
    "\n",
    "    def __init__(self):\n",
    "        env = gym.make('Pendulum-v1', render_mode='rgb_array')\n",
    "        super().__init__(env)\n",
    "        self.env = env\n",
    "        self.step_n = 0\n",
    "\n",
    "    def reset(self):\n",
    "        state, _ = self.env.reset()\n",
    "        self.step_n = 0\n",
    "        return state\n",
    "\n",
    "    def step(self, action):\n",
    "        state, reward, terminated, truncated, info = self.env.step(\n",
    "            [action * 2])\n",
    "        over = terminated or truncated\n",
    "\n",
    "        #偏移reward,便于训练\n",
    "        reward = (reward + 8) / 8\n",
    "\n",
    "        #限制最大步数\n",
    "        self.step_n += 1\n",
    "        if self.step_n >= 200:\n",
    "            over = True\n",
    "\n",
    "        return state, reward, over\n",
    "\n",
    "    #打印游戏图像\n",
    "    def show(self):\n",
    "        from matplotlib import pyplot as plt\n",
    "        plt.figure(figsize=(3, 3))\n",
    "        plt.imshow(self.env.render())\n",
    "        plt.show()\n",
    "\n",
    "\n",
    "env = MyWrapper()\n",
    "\n",
    "env.reset()\n",
    "\n",
    "env.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[-0.2374],\n",
       "         [-0.2333]], grad_fn=<TanhBackward0>),\n",
       " tensor([[0.9357],\n",
       "         [0.9717]], grad_fn=<ExpBackward0>))"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "class ModelAction(torch.nn.Module):\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.s = torch.nn.Sequential(\n",
    "            torch.nn.Linear(3, 64),\n",
    "            torch.nn.ReLU(),\n",
    "            torch.nn.Linear(64, 64),\n",
    "            torch.nn.ReLU(),\n",
    "        )\n",
    "        self.mu = torch.nn.Sequential(\n",
    "            torch.nn.Linear(64, 1),\n",
    "            torch.nn.Tanh(),\n",
    "        )\n",
    "        self.sigma = torch.nn.Sequential(\n",
    "            torch.nn.Linear(64, 1),\n",
    "            torch.nn.Tanh(),\n",
    "        )\n",
    "\n",
    "    def forward(self, state):\n",
    "        state = self.s(state)\n",
    "        return self.mu(state), self.sigma(state).exp()\n",
    "\n",
    "\n",
    "model_action = ModelAction()\n",
    "\n",
    "model_action(torch.randn(2, 3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.0019],\n",
       "        [0.0052]], grad_fn=<AddmmBackward0>)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_value = torch.nn.Sequential(\n",
    "    torch.nn.Linear(4, 64),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(64, 64),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(64, 1),\n",
    ")\n",
    "\n",
    "model_value_next = torch.nn.Sequential(\n",
    "    torch.nn.Linear(4, 64),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(64, 64),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(64, 1),\n",
    ")\n",
    "\n",
    "model_value_next.load_state_dict(model_value.state_dict())\n",
    "\n",
    "model_value(torch.randn(2, 4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "72.2250243315433"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from IPython import display\n",
    "import random\n",
    "\n",
    "\n",
    "#玩一局游戏并记录数据\n",
    "def play(show=False):\n",
    "    data = []\n",
    "    reward_sum = 0\n",
    "\n",
    "    state = env.reset()\n",
    "    over = False\n",
    "    while not over:\n",
    "        #根据概率采样\n",
    "        mu, sigma = model_action(torch.FloatTensor(state).reshape(1, 3))\n",
    "        action = random.normalvariate(mu=mu.item(), sigma=sigma.item())\n",
    "\n",
    "        next_state, reward, over = env.step(action)\n",
    "\n",
    "        data.append((state, action, reward, next_state, over))\n",
    "        reward_sum += reward\n",
    "\n",
    "        state = next_state\n",
    "\n",
    "        if show:\n",
    "            display.clear_output(wait=True)\n",
    "            env.show()\n",
    "\n",
    "    return data, reward_sum\n",
    "\n",
    "\n",
    "play()[-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_19891/3624659836.py:27: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at  ../torch/csrc/utils/tensor_new.cpp:201.)\n",
      "  state = torch.FloatTensor([i[0] for i in data]).reshape(-1, 3)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(200,\n",
       " (array([0.63902766, 0.76918375, 0.10462173], dtype=float32),\n",
       "  -0.45079428593689985,\n",
       "  0.903496999558007,\n",
       "  array([0.6177828 , 0.7863488 , 0.54627126], dtype=float32),\n",
       "  False))"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#数据池\n",
    "class Pool:\n",
    "\n",
    "    def __init__(self):\n",
    "        self.pool = []\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.pool)\n",
    "\n",
    "    def __getitem__(self, i):\n",
    "        return self.pool[i]\n",
    "\n",
    "    #更新动作池\n",
    "    def update(self):\n",
    "        #每次更新不少于N条新数据\n",
    "        old_len = len(self.pool)\n",
    "        while len(pool) - old_len < 200:\n",
    "            self.pool.extend(play()[0])\n",
    "\n",
    "        #只保留最新的N条数据\n",
    "        self.pool = self.pool[-2_0000:]\n",
    "\n",
    "    #获取一批数据样本\n",
    "    def sample(self):\n",
    "        data = random.sample(self.pool, 64)\n",
    "\n",
    "        state = torch.FloatTensor([i[0] for i in data]).reshape(-1, 3)\n",
    "        action = torch.FloatTensor([i[1] for i in data]).reshape(-1, 1)\n",
    "        reward = torch.FloatTensor([i[2] for i in data]).reshape(-1, 1)\n",
    "        next_state = torch.FloatTensor([i[3] for i in data]).reshape(-1, 3)\n",
    "        over = torch.LongTensor([i[4] for i in data]).reshape(-1, 1)\n",
    "\n",
    "        return state, action, reward, next_state, over\n",
    "\n",
    "\n",
    "pool = Pool()\n",
    "pool.update()\n",
    "state, action, reward, next_state, over = pool.sample()\n",
    "\n",
    "len(pool), pool[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer_action = torch.optim.Adam(model_action.parameters(), lr=5e-4)\n",
    "optimizer_value = torch.optim.Adam(model_value.parameters(), lr=5e-3)\n",
    "\n",
    "\n",
    "def soft_update(_from, _to):\n",
    "    for _from, _to in zip(_from.parameters(), _to.parameters()):\n",
    "        value = _to.data * 0.995 + _from.data * 0.005\n",
    "        _to.data.copy_(value)\n",
    "\n",
    "\n",
    "def get_action_entropy(state):\n",
    "    mu, sigma = model_action(torch.FloatTensor(state).reshape(-1, 3))\n",
    "    dist = torch.distributions.Normal(mu, sigma)\n",
    "\n",
    "    action = dist.rsample()\n",
    "    entropy = dist.log_prob(action) - (1 - action.tanh()**2 + 1e-8).log()\n",
    "    entropy = -entropy\n",
    "\n",
    "    return action, entropy\n",
    "\n",
    "\n",
    "def requires_grad(model, value):\n",
    "    for param in model.parameters():\n",
    "        param.requires_grad_(value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.4308680295944214"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def train_value(state, action, reward, next_state, over):\n",
    "    requires_grad(model_value, True)\n",
    "    requires_grad(model_action, False)\n",
    "\n",
    "    #计算target\n",
    "    with torch.no_grad():\n",
    "        #计算动作和熵\n",
    "        next_action, entropy = get_action_entropy(next_state)\n",
    "\n",
    "        #评估next_state的价值\n",
    "        input = torch.cat([next_state, next_action], dim=1)\n",
    "        target = model_value_next(input)\n",
    "\n",
    "    #加权熵,熵越大越好\n",
    "    target = target + 5e-3 * entropy\n",
    "    target = target * 0.99 * (1 - over) + reward\n",
    "\n",
    "    #计算value\n",
    "    value = model_value(torch.cat([state, action], dim=1))\n",
    "\n",
    "    loss = torch.nn.functional.mse_loss(value, target)\n",
    "\n",
    "    loss.backward()\n",
    "    optimizer_value.step()\n",
    "    optimizer_value.zero_grad()\n",
    "\n",
    "    return loss.item()\n",
    "\n",
    "\n",
    "train_value(state, action, reward, next_state, over)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-0.4163397550582886"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def train_action(state):\n",
    "    requires_grad(model_value, False)\n",
    "    requires_grad(model_action, True)\n",
    "\n",
    "    #计算action和熵\n",
    "    action, entropy = get_action_entropy(state)\n",
    "\n",
    "    #计算value\n",
    "    value = model_value(torch.cat([state, action], dim=1))\n",
    "\n",
    "    #加权熵,熵越大越好\n",
    "    loss = -(value + 5e-3 * entropy).mean()\n",
    "\n",
    "    #使用model_value计算model_action的loss\n",
    "    loss.backward()\n",
    "    optimizer_action.step()\n",
    "    optimizer_action.zero_grad()\n",
    "\n",
    "    return loss.item()\n",
    "\n",
    "\n",
    "train_action(state)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "id": "OHoSU6uI-xIt",
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 400 33.43052610713056\n",
      "10 2400 117.36476210537043\n",
      "20 4400 180.2052989746231\n",
      "30 6400 176.09416724964098\n",
      "40 8400 177.69048961029395\n",
      "50 10400 181.38156840913834\n",
      "60 12400 180.11599751167884\n",
      "70 14400 182.01393661808527\n",
      "80 16400 178.62870999844935\n",
      "90 18400 176.47998019033346\n"
     ]
    }
   ],
   "source": [
    "def train():\n",
    "    model_action.train()\n",
    "    model_value.train()\n",
    "\n",
    "    #训练N次\n",
    "    for epoch in range(100):\n",
    "        #更新N条数据\n",
    "        pool.update()\n",
    "\n",
    "        #每次更新过数据后,学习N次\n",
    "        for i in range(200):\n",
    "            #采样一批数据\n",
    "            state, action, reward, next_state, over = pool.sample()\n",
    "\n",
    "            #训练\n",
    "            train_value(state, action, reward, next_state, over)\n",
    "            train_action(state)\n",
    "            soft_update(model_value, model_value_next)\n",
    "\n",
    "        if epoch % 10 == 0:\n",
    "            test_result = sum([play()[-1] for _ in range(20)]) / 20\n",
    "            print(epoch, len(pool), test_result)\n",
    "\n",
    "\n",
    "train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAR8AAAEXCAYAAACUBEAgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAbsElEQVR4nO3dfXBU9b0G8Gc3+5LXsyFAdklJSm5BIcNLNWDYeqe2JSXa1EqNdyzD2JRydaALl5cOU9IKTpnOhIsztVIVO+MI/IPpxVuwpqDmBgx1WAEjqeEt1Tu0icBuEMxuEsi+fu8fknNdjJgN2f1l1+czc2bY8/tt8hyYfTh7zp6zBhEREBElmVF1ACL6cmL5EJESLB8iUoLlQ0RKsHyISAmWDxEpwfIhIiVYPkSkBMuHiJRg+RCREsrK59lnn8WUKVOQmZmJiooKHDt2TFUUIlJASfn88Y9/xLp16/DEE0/g3XffxZw5c1BVVYXu7m4VcYhIAYOKC0srKiowb948PPPMMwCAaDSK4uJirFq1Chs2bPjC50ejUVy4cAF5eXkwGAyJjktEwyQi6O3tRVFREYzGm+/bmJKUSRcMBtHa2oq6ujp9ndFoRGVlJdxu95DPCQQCCAQC+uPz58+jrKws4VmJaGS6urowefLkm85Jevl89NFHiEQisNvtMevtdjvOnj075HPq6+vx61//+jPru7q6oGlaQnISUfz8fj+Ki4uRl5f3hXOTXj4jUVdXh3Xr1umPBzdQ0zSWD9EYNJzDIUkvnwkTJiAjIwNerzdmvdfrhcPhGPI5VqsVVqs1GfGIKEmSfrbLYrGgvLwczc3N+rpoNIrm5mY4nc5kxyEiRZS87Vq3bh1qa2sxd+5c3HXXXfjd736H/v5+LF26VEUcIlJASfk8/PDDuHTpEjZt2gSPx4Ovf/3reO211z5zEJqI0peSz/ncKr/fD5vNBp/PxwPORGNIPK9NXttFREqwfIhICZYPESnB8iEiJVg+RKQEy4eIlGD5EJESLB8iUoLlQ0RKsHyISAmWDxEpwfIhIiVYPkSkBMuHiJRg+RCREiwfIlKC5UNESrB8iEgJlg8RKcHyISIlWD5EpATLh4iUYPkQkRIsHyJSguVDREqwfIhICZYPESnB8iEiJVg+RKQEy4eIlGD5EJESLB8iUoLlQ0RKsHyISAmWDxEpEXf5HD58GPfffz+KiopgMBiwb9++mHERwaZNmzBp0iRkZWWhsrIS77//fsycK1euYMmSJdA0Dfn5+Vi2bBn6+vpuaUOIKLXEXT79/f2YM2cOnn322SHHt27dim3btuH555/H0aNHkZOTg6qqKgwMDOhzlixZglOnTqGpqQmNjY04fPgwHnvssZFvBRGlHrkFAGTv3r3642g0Kg6HQ5588kl9XU9Pj1itVnnppZdEROT06dMCQI4fP67POXDggBgMBjl//vywfq/P5xMA4vP5biU+EY2yeF6bo3rM59y5c/B4PKisrNTX2Ww2VFRUwO12AwDcbjfy8/Mxd+5cfU5lZSWMRiOOHj065M8NBALw+/0xCxGltlEtH4/HAwCw2+0x6+12uz7m8XhQWFgYM24ymVBQUKDPuVF9fT1sNpu+FBcXj2ZsIlIgJc521dXVwefz6UtXV5fqSER0i0a1fBwOBwDA6/XGrPd6vfqYw+FAd3d3zHg4HMaVK1f0OTeyWq3QNC1mIaLUNqrlU1paCofDgebmZn2d3+/H0aNH4XQ6AQBOpxM9PT1obW3V5xw8eBDRaBQVFRWjGYeIxjBTvE/o6+vDBx98oD8+d+4c2traUFBQgJKSEqxZswa/+c1vMG3aNJSWlmLjxo0oKirCokWLAAAzZszAvffei0cffRTPP/88QqEQVq5ciR/96EcoKioatQ0jojEu3lNphw4dEgCfWWpra0Xkk9PtGzduFLvdLlarVRYsWCAdHR0xP+Py5cuyePFiyc3NFU3TZOnSpdLb2zvsDDzVTjQ2xfPaNIiIKOy+EfH7/bDZbPD5fDz+QzSGxPPaTImzXUSUflg+RKQEy4eIlIj7bBdRIogIIv39CFy8iOjAADJycmAtKoLRaoXBYFAdjxKA5UPKRUMhfHzkCLpfeQUDH36IaCCAjKwsZH3ta5j00EPImzMHBiN30tMNy4eUknAY3Y2NuPjSS4h+6rYrkatX0dfejnOdnSh+7DGM+9d/5R5QmuF/J6SMiMD/3nvw/Nd/xRTPp4V9Pnz44osY4PV8aYflQ+pEIuhubESkv/+m00KXL+PSgQNIwY+k0U2wfEgJEUFfRwf6Tp0a1vyBDz9McCJKNpYPKTPQ2YnotWuqY5AiLB9SJswvDfhSY/mQMmHeDvdLjeVDaoigv6Nj2NMNJn4qJN2wfEgZCYWGPTdv5swEJiEVWD6UEkw2m+oINMpYPqSERCKQaHTY8015eQlMQyqwfEiJaDAY19suo8XCyyvSDMuHlIgODCAaR/lQ+mH5kBKhjz9GuLdXdQxSiOVDSoR7exG9enVYczNycnjAOQ2xfGjMy8jLg3ncONUxaJSxfCjpRASI4wp1o9kMY2ZmAhORCiwfUuLz7t8zFIPZDKPFksA0pALLh5SI52CzwWAAeJo97bB8SImwz6c6AinG8iElhnsTMQAw8C1XWmL5kBISiQx7bu6MGXzblYZYPjTmmb7gO78pNbF8KPlEeFEpsXwo+SQUQjQYHPZ8o9WawDSkCsuHki4aDEICgbiewyva0w/Lh5Iu7PMheOWK6hikGMuHki5y9Soiw/zmCmNWFswTJiQ4EanA8qExLSMzE5aCAtUxKAFYPjSmGUwmGLOyVMegBIirfOrr6zFv3jzk5eWhsLAQixYtQscNX38yMDAAl8uF8ePHIzc3FzU1NfB6vTFzOjs7UV1djezsbBQWFmL9+vUIh8O3vjWUEqKBwLCvajeYTLyiPU3FVT4tLS1wuVx4++230dTUhFAohIULF6K/v1+fs3btWrz66qvYs2cPWlpacOHCBTz44IP6eCQSQXV1NYLBII4cOYJdu3Zh586d2LRp0+htFY1pcX1TqcHAM11pyiASx41VbnDp0iUUFhaipaUF3/zmN+Hz+TBx4kTs3r0bDz30EADg7NmzmDFjBtxuN+bPn48DBw7g+9//Pi5cuAC73Q4AeP755/GLX/wCly5dgmUY1/H4/X7YbDb4fD5o/PRryrn0xhvofOaZYc21FhWhbNs23lIjRcTz2rylYz6+61cmF1w/INja2opQKITKykp9zvTp01FSUgK32w0AcLvdmDVrll48AFBVVQW/349Tn3OxYSAQgN/vj1kodV07d27Yc41WK6/rSlMjLp9oNIo1a9bg7rvvxszr3ybp8XhgsViQn58fM9dut8Pj8ehzPl08g+ODY0Opr6+HzWbTl+Li4pHGpjEg9PHHw56bc9ttMBh5XiQdjfhf1eVy4eTJk2hoaBjNPEOqq6uDz+fTl66uroT/ThobMnJzueeTpkwjedLKlSvR2NiIw4cPY/Lkyfp6h8OBYDCInp6emL0fr9cLh8Ohzzl27FjMzxs8GzY450ZWqxVWXt+TFuK9f7OJ5ZO24trzERGsXLkSe/fuxcGDB1FaWhozXl5eDrPZjObmZn1dR0cHOjs74XQ6AQBOpxPt7e3o7u7W5zQ1NUHTNJSVld3KtlAKkHD4k1Ptw8SLStNXXHs+LpcLu3fvxiuvvIK8vDz9GI3NZkNWVhZsNhuWLVuGdevWoaCgAJqmYdWqVXA6nZg/fz4AYOHChSgrK8MjjzyCrVu3wuPx4PHHH4fL5eLezZeAhMNx3Tyep9rTV1zls337dgDAt771rZj1O3bswE9+8hMAwFNPPQWj0YiamhoEAgFUVVXhueee0+dmZGSgsbERK1asgNPpRE5ODmpra7F58+Zb2xJKCdFgkN9USgBu8XM+qvBzPqlr4MMPcfo//gMyjE+0GywWfG3DBtjmzk1CMhoNSfucD1EiGc1mXtGexlg+NGYZMjKQkZ2tOgYlCMuHkioaCg3/VHtGBjJ4UWnaYvlQUkX6+zHcg4wGAMjISGAaUonlQ0kV7uuL60OGlL5YPpRUVz/4ABjm1+YYMzP5GZ80xvKhpArHcUeCrH/5FxjM5gSmIZVYPjRmmXJyeEV7GuO/LCVNvJ9nzcjJ4UWlaYzlQ8kTjcZ/USnLJ22xfChpJBJB5OrV4T+BF5WmNZYPJY2EQgj39KiOQWPEiG4mRjQSkYEBXOvs1B+LCHqCQfxvby9sFgu+lpcH4/U9HYPJhMxP3aiO0g/Lh5QQEXT292PjiRPo8PmQYzLh32+7DQ+XliLDYACMRpjHj1cdkxKIb7tICQHwn+3tON3Tg4gI/KEQnjlzBiev31zeYDTClJOjNiQlFMuHkiYjKwtZU6boj/2hUMx4MBpFIBL55IHBwCva0xzLh5ImIzsbudOnA/jkotFvOxwwfeps1m2ahq/m5n4yNzOTn25OczzmQ0k1/tvfxkdNTYj09aF26lTkmc34n4sXMSkrC4/edhsKr99CY9w998DEu1SmNZYPJVVmSQkmLV6M87t2wRQM4t+mTMFDU6ZgcP/HYDAgt6wM9kWLeGlFmmP5UFIZjEZMvPdeAID3v/8boY8/huH6ZRcGsxnaHXdg8rJlsFz/Cm5KXywfSjqj2YzC6mrY7rwT/hMnEPB4YMzKQu6MGcidMQMZWVmqI1ISsHxICYPRiMyvfAWZX/mK6iikCN9UE5ESLB8iUoLlQ0RKsHyISAmWDxEpwfIhIiVYPkSkBMuHiJRg+RCREiwfIlKC5UNESrB8iEgJlg8RKcHyISIl4iqf7du3Y/bs2dA0DZqmwel04sCBA/r4wMAAXC4Xxo8fj9zcXNTU1MDr9cb8jM7OTlRXVyM7OxuFhYVYv349wuHw6GwNEaWMuMpn8uTJ2LJlC1pbW/HOO+/gO9/5Dh544AGcOnUKALB27Vq8+uqr2LNnD1paWnDhwgU8+OCD+vMjkQiqq6sRDAZx5MgR7Nq1Czt37sSmTZtGd6uIaOyTWzRu3Dh54YUXpKenR8xms+zZs0cfO3PmjAAQt9stIiL79+8Xo9EoHo9Hn7N9+3bRNE0CgcDn/o6BgQHx+Xz60tXVJQDE5/PdanwiGkU+n2/Yr80RH/OJRCJoaGhAf38/nE4nWltbEQqFUFlZqc+ZPn06SkpK4Ha7AQButxuzZs2C3W7X51RVVcHv9+t7T0Opr6+HzWbTl+Li4pHGJqIxIu7yaW9vR25uLqxWK5YvX469e/eirKwMHo8HFosF+fn5MfPtdjs8Hg8AwOPxxBTP4Pjg2Oepq6uDz+fTl66urnhjE9EYE/c9nG+//Xa0tbXB5/Ph5ZdfRm1tLVpaWhKRTWe1WmG1WhP6O4goueIuH4vFgqlTpwIAysvLcfz4cTz99NN4+OGHEQwG0dPTE7P34/V64XA4AAAOhwPHjh2L+XmDZ8MG5xDRl8Mtf84nGo0iEAigvLwcZrMZzc3N+lhHRwc6OzvhdDoBAE6nE+3t7eju7tbnNDU1QdM0lJWV3WoUIkohce351NXV4b777kNJSQl6e3uxe/duvPnmm3j99ddhs9mwbNkyrFu3DgUFBdA0DatWrYLT6cT8+fMBAAsXLkRZWRkeeeQRbN26FR6PB48//jhcLhffVhF9ycRVPt3d3fjxj3+MixcvwmazYfbs2Xj99dfx3e9+FwDw1FNPwWg0oqamBoFAAFVVVXjuuef052dkZKCxsRErVqyA0+lETk4OamtrsXnz5tHdKiIa8wwi17+rNoX4/X7YbDb4fD5omqY6DhFdF89rk9d2EZESLB8iUoLlQ0RKsHyISAmWDxEpwfIhIiVYPkSkBMuHiJRg+RCREiwfIlKC5UNESrB8iEgJlg8RKcHyISIlWD5EpATLh4iUYPkQkRIsHyJSguVDREqwfIhICZYPESnB8iEiJVg+RKQEy4eIlGD5EJESLB8iUoLlQ0RKsHyISAmWDxEpwfIhIiVYPkSkBMuHiJRg+RCREiwfIlKC5UNEStxS+WzZsgUGgwFr1qzR1w0MDMDlcmH8+PHIzc1FTU0NvF5vzPM6OztRXV2N7OxsFBYWYv369QiHw7cShYhSzIjL5/jx4/jDH/6A2bNnx6xfu3YtXn31VezZswctLS24cOECHnzwQX08EomguroawWAQR44cwa5du7Bz505s2rRp5FtBRKlHRqC3t1emTZsmTU1Ncs8998jq1atFRKSnp0fMZrPs2bNHn3vmzBkBIG63W0RE9u/fL0ajUTwejz5n+/btommaBAKBIX/fwMCA+Hw+fenq6hIA4vP5RhKfiBLE5/MN+7U5oj0fl8uF6upqVFZWxqxvbW1FKBSKWT99+nSUlJTA7XYDANxuN2bNmgW73a7Pqaqqgt/vx6lTp4b8ffX19bDZbPpSXFw8kthENIbEXT4NDQ149913UV9f/5kxj8cDi8WC/Pz8mPV2ux0ej0ef8+niGRwfHBtKXV0dfD6fvnR1dcUbm4jGGFM8k7u6urB69Wo0NTUhMzMzUZk+w2q1wmq1Ju33EVHixbXn09raiu7ubtx5550wmUwwmUxoaWnBtm3bYDKZYLfbEQwG0dPTE/M8r9cLh8MBAHA4HJ85+zX4eHAOEaW/uMpnwYIFaG9vR1tbm77MnTsXS5Ys0f9sNpvR3NysP6ejowOdnZ1wOp0AAKfTifb2dnR3d+tzmpqaoGkaysrKRmmziGisi+ttV15eHmbOnBmzLicnB+PHj9fXL1u2DOvWrUNBQQE0TcOqVavgdDoxf/58AMDChQtRVlaGRx55BFu3boXH48Hjjz8Ol8vFt1ZEXyJxlc9wPPXUUzAajaipqUEgEEBVVRWee+45fTwjIwONjY1YsWIFnE4ncnJyUFtbi82bN492FCIawwwiIqpDxMvv98Nms8Hn80HTNNVxiOi6eF6bvLaLiJRg+RCREiwfIlKC5UNESrB8iEgJlg8RKcHyISIlWD5EpATLh4iUYPkQkRIsHyJSguVDREqwfIhICZYPESnB8iEiJVg+RKQEy4eIlGD5EJESLB8iUoLlQ0RKsHyISAmWDxEpwfIhIiVYPkSkBMuHiJRg+RCREiwfIlKC5UNESrB8iEgJlg8RKcHyISIlWD5EpATLh4iUYPkQkRIsHyJSguVDREqwfIhICZPqACMhIgAAv9+vOAkRfdrga3LwNXozKVk+ly9fBgAUFxcrTkJEQ+nt7YXNZrvpnJQsn4KCAgBAZ2fnF27gWOP3+1FcXIyuri5omqY6zrAxd3Klam4RQW9vL4qKir5wbkqWj9H4yaEqm82WUv8wn6ZpWkpmZ+7kSsXcw90h4AFnIlKC5UNESqRk+VitVjzxxBOwWq2qo8QtVbMzd3Klau54GGQ458SIiEZZSu75EFHqY/kQkRIsHyJSguVDREqwfIhIiZQsn2effRZTpkxBZmYmKioqcOzYMaV5Dh8+jPvvvx9FRUUwGAzYt29fzLiIYNOmTZg0aRKysrJQWVmJ999/P2bOlStXsGTJEmiahvz8fCxbtgx9fX0JzV1fX4958+YhLy8PhYWFWLRoETo6OmLmDAwMwOVyYfz48cjNzUVNTQ28Xm/MnM7OTlRXVyM7OxuFhYVYv349wuFwwnJv374ds2fP1j/963Q6ceDAgTGdeShbtmyBwWDAmjVrUi77qJAU09DQIBaLRV588UU5deqUPProo5Kfny9er1dZpv3798uvfvUr+dOf/iQAZO/evTHjW7ZsEZvNJvv27ZO//e1v8oMf/EBKS0vl2rVr+px7771X5syZI2+//bb89a9/lalTp8rixYsTmruqqkp27NghJ0+elLa2Nvne974nJSUl0tfXp89Zvny5FBcXS3Nzs7zzzjsyf/58+cY3vqGPh8NhmTlzplRWVsqJEydk//79MmHCBKmrq0tY7j//+c/yl7/8Rf7+979LR0eH/PKXvxSz2SwnT54cs5lvdOzYMZkyZYrMnj1bVq9era9PheyjJeXK56677hKXy6U/jkQiUlRUJPX19QpT/b8byycajYrD4ZAnn3xSX9fT0yNWq1VeeuklERE5ffq0AJDjx4/rcw4cOCAGg0HOnz+ftOzd3d0CQFpaWvScZrNZ9uzZo885c+aMABC32y0inxSv0WgUj8ejz9m+fbtomiaBQCBp2ceNGycvvPBCSmTu7e2VadOmSVNTk9xzzz16+aRC9tGUUm+7gsEgWltbUVlZqa8zGo2orKyE2+1WmOzznTt3Dh6PJyazzWZDRUWFntntdiM/Px9z587V51RWVsJoNOLo0aNJy+rz+QD8/10DWltbEQqFYrJPnz4dJSUlMdlnzZoFu92uz6mqqoLf78epU6cSnjkSiaChoQH9/f1wOp0pkdnlcqG6ujomI5Aaf9+jKaWuav/oo48QiURi/uIBwG634+zZs4pS3ZzH4wGAITMPjnk8HhQWFsaMm0wmFBQU6HMSLRqNYs2aNbj77rsxc+ZMPZfFYkF+fv5Nsw+1bYNjidLe3g6n04mBgQHk5uZi7969KCsrQ1tb25jNDAANDQ149913cfz48c+MjeW/70RIqfKhxHG5XDh58iTeeust1VGG5fbbb0dbWxt8Ph9efvll1NbWoqWlRXWsm+rq6sLq1avR1NSEzMxM1XGUS6m3XRMmTEBGRsZnjv57vV44HA5FqW5uMNfNMjscDnR3d8eMh8NhXLlyJSnbtXLlSjQ2NuLQoUOYPHmyvt7hcCAYDKKnp+em2YfatsGxRLFYLJg6dSrKy8tRX1+POXPm4Omnnx7TmVtbW9Hd3Y0777wTJpMJJpMJLS0t2LZtG0wmE+x2+5jNnggpVT4WiwXl5eVobm7W10WjUTQ3N8PpdCpM9vlKS0vhcDhiMvv9fhw9elTP7HQ60dPTg9bWVn3OwYMHEY1GUVFRkbBsIoKVK1di7969OHjwIEpLS2PGy8vLYTabY7J3dHSgs7MzJnt7e3tMeTY1NUHTNJSVlSUs+42i0SgCgcCYzrxgwQK0t7ejra1NX+bOnYslS5bofx6r2RNC9RHveDU0NIjVapWdO3fK6dOn5bHHHpP8/PyYo//J1tvbKydOnJATJ04IAPntb38rJ06ckH/+858i8smp9vz8fHnllVfkvffekwceeGDIU+133HGHHD16VN566y2ZNm1awk+1r1ixQmw2m7z55pty8eJFfbl69ao+Z/ny5VJSUiIHDx6Ud955R5xOpzidTn188NTvwoULpa2tTV577TWZOHFiQk/9btiwQVpaWuTcuXPy3nvvyYYNG8RgMMgbb7wxZjN/nk+f7Uq17Lcq5cpHROT3v/+9lJSUiMVikbvuukvefvttpXkOHTokAD6z1NbWisgnp9s3btwodrtdrFarLFiwQDo6OmJ+xuXLl2Xx4sWSm5srmqbJ0qVLpbe3N6G5h8oMQHbs2KHPuXbtmvzsZz+TcePGSXZ2tvzwhz+Uixcvxvycf/zjH3LfffdJVlaWTJgwQX7+859LKBRKWO6f/vSn8tWvflUsFotMnDhRFixYoBfPWM38eW4sn1TKfqt4Px8iUiKljvkQUfpg+RCREiwfIlKC5UNESrB8iEgJlg8RKcHyISIlWD5EpATLh4iUYPkQkRIsHyJS4v8A5Fo1rzjcANwAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 300x300 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "183.81251913645426"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "play(True)[-1]"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "第7章-DQN算法.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python [conda env:pt39]",
   "language": "python",
   "name": "conda-env-pt39-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
